{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4dc0e955-6f3b-44b9-9ae6-30c24df41037",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0231424683f6431f8f37385dcb8a3f51",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig\n",
    "\n",
    "# Local checkpoint directories for the tokenizer and the base model.\n",
    "tokenizer_path = '/root/autodl-tmp/Llama-2-7b-hf'\n",
    "model_path = '/root/autodl-tmp/OpenELM-3B-Instruct'\n",
    "\n",
    "# NOTE(review): the tokenizer comes from a Llama-2 checkpoint while the weights\n",
    "# come from OpenELM -- presumably intentional (shared vocabulary); confirm.\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n",
    "# Reuse the EOS token as the padding token.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Load the base model in bfloat16 with automatic device placement.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    trust_remote_code=True,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "efbcc990-5c57-4b39-8138-dbf0dd910b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_func(example, max_length=384):\n",
    "    \"\"\"Tokenize one instruction-tuning record into causal-LM training features.\n",
    "\n",
    "    The prompt is `en_instruction` concatenated with `en_input` and a literal\n",
    "    `<sep>` marker; the target is `en_output` followed by the tokenizer's\n",
    "    end-of-sequence token. Prompt positions are set to -100 in `labels`\n",
    "    (the ignore value used by the Hugging Face loss) so only the response\n",
    "    and the final EOS are supervised.\n",
    "\n",
    "    Note: an earlier variant of this function read the un-prefixed\n",
    "    `instruction` / `input` / `output` columns instead of the `en_*` ones.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with string fields `en_instruction`, `en_input`\n",
    "            and `en_output`.\n",
    "        max_length: Maximum number of tokens kept per sample. Longer samples\n",
    "            are cut on the right, which can drop the trailing EOS. `None`\n",
    "            disables truncation.\n",
    "\n",
    "    Returns:\n",
    "        dict with three equally long lists: `input_ids`, `attention_mask`\n",
    "        and `labels`.\n",
    "    \"\"\"\n",
    "    prompt = tokenizer(f\"{example['en_instruction'] + example['en_input']}<sep>\", add_special_tokens=True)\n",
    "    answer = tokenizer(f\"{example['en_output']}\", add_special_tokens=False)\n",
    "\n",
    "    # Terminate with EOS explicitly. The previous code appended pad_token_id,\n",
    "    # which is only an end-of-sequence marker while pad_token is aliased to\n",
    "    # eos_token; using eos_token_id keeps the target correct if that changes.\n",
    "    eos = [tokenizer.eos_token_id]\n",
    "    prompt_len = len(prompt[\"input_ids\"])\n",
    "\n",
    "    input_ids = prompt[\"input_ids\"] + answer[\"input_ids\"] + eos\n",
    "    # The appended EOS is a real token, so it is attended to (mask value 1).\n",
    "    attention_mask = prompt[\"attention_mask\"] + answer[\"attention_mask\"] + [1]\n",
    "    labels = [-100] * prompt_len + answer[\"input_ids\"] + eos\n",
    "\n",
    "    # Right-truncate; slicing is a no-op for samples already within the limit.\n",
    "    return {\n",
    "        \"input_ids\": input_ids[:max_length],\n",
    "        \"attention_mask\": attention_mask[:max_length],\n",
    "        \"labels\": labels[:max_length]\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "52dbb09b-63a6-4b34-9069-a9645d5d442e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "252ee681232a475e9fbdf14de5e0edcc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/52002 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 52002\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from datasets import Dataset\n",
    "\n",
    "# Folder that holds the JSON shards of the dataset.\n",
    "folder_path = '/root/autodl-tmp/alpaca-chinese-dataset/data/'\n",
    "\n",
    "# os.listdir() returns entries in arbitrary order; sort the shard names so the\n",
    "# row order of the merged dataset is the same on every run.\n",
    "json_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.json'))\n",
    "if not json_files:\n",
    "    # Fail early with a clear message instead of training on an empty dataset.\n",
    "    raise FileNotFoundError(f\"No .json files found in {folder_path}\")\n",
    "\n",
    "# Read every shard, then concatenate once; concatenating inside the loop\n",
    "# re-copies all previously loaded rows on each iteration.\n",
    "combined_df = pd.concat(\n",
    "    [pd.read_json(os.path.join(folder_path, json_file)) for json_file in json_files],\n",
    "    ignore_index=True,\n",
    ")\n",
    "\n",
    "# Convert the merged DataFrame to a Dataset and tokenize every row, dropping\n",
    "# the raw columns so only the fields returned by process_func remain.\n",
    "ds = Dataset.from_pandas(combined_df)\n",
    "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
    "tokenized_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "65629d7a-9dd2-4421-92b3-8aa404c5ff49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the modules that receive LoRA adapters.\n",
    "lora_target_modules = ['token_embeddings', \"qkv_proj\", \"out_proj\", \"proj_1\", \"proj_2\"]\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=32,  # LoRA rank\n",
    "    lora_alpha=32,  # LoRA alpha; see the LoRA method description for its role\n",
    "    lora_dropout=0.1,  # dropout ratio\n",
    "    target_modules=lora_target_modules,\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    inference_mode=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8906b62b-d611-42d6-8327-e4696a01540b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='./OpenELM-3B-Instruct', revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=32, target_modules={'token_embeddings', 'proj_1', 'proj_2', 'out_proj', 'qkv_proj'}, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Enable gradients on the model inputs before wrapping with PEFT.\n",
    "# NOTE(review): presumably required because gradient checkpointing is enabled\n",
    "# in the TrainingArguments cell -- confirm.\n",
    "model.enable_input_require_grads()\n",
    "# Wrap the base model with the LoRA adapters described by `config`.\n",
    "peft_model = get_peft_model(model, config)\n",
    "# Put the wrapped model into training mode.\n",
    "peft_model.train()\n",
    "# Display the LoRA configuration as the cell output.\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "431a64b2-9e1b-4f40-b1d4-8f445a4a1090",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 45,883,392 || all params: 3,082,530,816 || trainable%: 1.4885\n"
     ]
    }
   ],
   "source": [
    "# Print the trainable vs. total parameter counts of the LoRA-wrapped model.\n",
    "peft_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "efc94568-3f11-4aa8-8235-d45b4e73599a",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    # Where checkpoints and logs are written.\n",
    "    output_dir=\"/root/autodl-tmp/output/openelm_3B_lora\",\n",
    "    save_steps=600,\n",
    "    save_on_each_node=True,\n",
    "    logging_steps=100,\n",
    "    # Batch of 8 per device, accumulated over 4 steps before each update.\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    gradient_checkpointing=True,\n",
    "    learning_rate=1e-4,\n",
    "    # For a quick demo we only train to roughly 1200 iterations as a test;\n",
    "    # 10 epochs is the recommended setting.\n",
    "    num_train_epochs=0.87,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4cfcba9e-1165-4125-8624-b916de721a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collator that pads the samples of each batch (padding=True) using `tokenizer`.\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)\n",
    "\n",
    "# Trainer over the LoRA-wrapped model and the tokenized training set.\n",
    "trainer = Trainer(\n",
    "    model=peft_model,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72245bbf-5376-47ec-8cfd-4d4cb81bf63a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start LoRA fine-tuning with the Trainer built from `peft_model`, `args` and `tokenized_id`.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8cb8d4-8dad-4c03-b6ff-8c036d26e5e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
